{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Reformer: Text Generation",
      "provenance": [],
      "collapsed_sections": [
        "udDs_biH0n5U"
      ]
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "udDs_biH0n5U",
        "colab_type": "text"
      },
      "source": [
        "#### Copyright 2020 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WPY-OyyM0pSs",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "psnUF-8c02o_",
        "colab_type": "text"
      },
      "source": [
        "# Reformer: Text Generation [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/text_generation.ipynb)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1lnRd_IoERdk",
        "colab_type": "text"
      },
      "source": [
        "This notebook was designed to run on TPU.\n",
        "\n",
        "To use TPUs in Colab, click \"Runtime\" on the main menu bar and select \"Change runtime type\". Set \"TPU\" as the hardware accelerator."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8PluCmWbZIpJ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Install JAX.\n",
        "!gsutil cp gs://trax-ml/reformer/jaxlib-0.1.39-cp36-none-manylinux2010_x86_64.whl .\n",
        "!gsutil cp gs://trax-ml/reformer/jax-0.1.59-cp36-none-manylinux2010_x86_64.whl .\n",
        "!pip install --upgrade -q ./jaxlib-0.1.39-cp36-none-manylinux2010_x86_64.whl\n",
        "!pip install --upgrade -q ./jax-0.1.59-cp36-none-manylinux2010_x86_64.whl\n",
        "\n",
        "# Make sure the Colab Runtime is set to Accelerator: TPU.\n",
        "import requests\n",
        "import os\n",
        "if 'TPU_DRIVER_MODE' not in globals():\n",
        "  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'\n",
        "  resp = requests.post(url)\n",
        "  TPU_DRIVER_MODE = 1\n",
        "\n",
        "# The following is required to use TPU Driver as JAX's backend.\n",
        "from jax.config import config\n",
        "config.FLAGS.jax_xla_backend = \"tpu_driver\"\n",
        "config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']\n",
        "print(config.FLAGS.jax_backend_target)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yiPdBenoZwH6",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!pip install --upgrade -q sentencepiece\n",
        "!pip install --upgrade -q gin git+https://github.com/google/trax.git@v1.2.3\n",
        "\n",
        "from tensorflow.compat.v1.io.gfile import GFile\n",
        "import gin\n",
        "import os\n",
        "import jax\n",
        "import trax\n",
        "from trax.models.beam_search import Search\n",
        "from trax.supervised import inputs\n",
        "\n",
        "import numpy as np\n",
        "import jax.numpy as jnp\n",
        "\n",
        "from scipy.special import softmax\n",
        "\n",
        "from sentencepiece import SentencePieceProcessor"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "FQ89jHCYfhpg"
      },
      "source": [
        "## Setting up data and model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9_OCIqghSyfs",
        "colab_type": "text"
      },
      "source": [
        "In this notebook, we'll be pushing the limits of just how many tokens we can fit on a single TPU device. The TPUs available in Colab have 8 cores with 8GB of memory each. We will set up a Reformer model that can fit a copy of \"Crime and Punishment\" on *each* of the 8 TPU cores (over 500,000 tokens per 8GB of memory)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tYSOVGR47LVL",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Import a copy of \"Crime and Punishment\", by Fyodor Dostoevsky\n",
        "with GFile('gs://trax-ml/reformer/crime-and-punishment-2554.txt') as f:\n",
        "  text = f.read()\n",
        "\n",
        "# The file read above includes metadata and licensing information.\n",
        "# For training our language model, we will only use the actual novel text.\n",
        "start = text.find('CRIME AND PUNISHMENT')  # skip header\n",
        "start = text.find('CRIME AND PUNISHMENT', start + 1)  # skip header\n",
        "start = text.find('CRIME AND PUNISHMENT', start + 1)  # skip translator preface\n",
        "end = text.rfind('End of Project')  # skip extra text at the end\n",
        "text = text[start:end].strip()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mMntV3H-6OR0",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Load a BPE vocabulary with 320 types. This mostly consists of single letters\n",
        "# and pairs of letters, but it has some common words and word pieces, too.\n",
        "!gsutil cp gs://trax-ml/reformer/cp.320.* .\n",
        "\n",
        "TOKENIZER = SentencePieceProcessor()\n",
        "TOKENIZER.load('cp.320.model')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HnJzxSi_77zP",
        "colab_type": "code",
        "outputId": "ec510c06-5a49-42aa-ebde-585e487348b7",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "# Tokenize\n",
        "IDS = TOKENIZER.EncodeAsIds(text)\n",
        "IDS = np.asarray(IDS, dtype=np.int32)\n",
        "PAD_AMOUNT = 512 * 1024 - len(IDS)\n",
        "print(\"Number of tokens:\", IDS.shape[0])"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Number of tokens: 513812\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bzQ7G9uGSga5",
        "colab_type": "text"
      },
      "source": [
        "As we see above, \"Crime and Punishment\" has just over half a million tokens with the BPE vocabulary we have selected.\n",
        "\n",
        "Normally we would have a dataset with many examples, but for this demonstration we fit a language model on the single novel only. We don't want the model to just memorize the dataset by encoding the words in its position embeddings, so at each training iteration we will randomly select how much padding to put before the text vs. after it.\n",
        "\n",
        "We have 8 TPU cores, so we will separately randomize the amount of padding for each core."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PdAwmpS220ub",
        "colab_type": "code",
        "outputId": "ff1e17a9-f63d-4c02-ac19-877737a5673c",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "# Set up the data pipeline.\n",
        "def my_inputs(n_devices):\n",
        "  while True:\n",
        "    inputs = []\n",
        "    mask = []\n",
        "    pad_amounts = np.random.choice(PAD_AMOUNT, n_devices)\n",
        "    for i in range(n_devices):\n",
        "      inputs.append(np.pad(IDS, (pad_amounts[i], PAD_AMOUNT - pad_amounts[i]),\n",
        "                            mode='constant'))\n",
        "      mask.append(np.pad(np.ones_like(IDS, dtype=np.float32),\n",
        "                          (pad_amounts[i], PAD_AMOUNT - pad_amounts[i]),\n",
        "                          mode='constant'))\n",
        "    inputs = np.stack(inputs)\n",
        "    mask = np.stack(mask)\n",
        "    yield (inputs, inputs, mask)\n",
        "\n",
        "print(\"(device count, tokens per device) = \",\n",
        "      next(my_inputs(trax.fastmath.device_count()))[0].shape)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "(device count, tokens per device) =  (8, 524288)\n"
          ],
          "name": "stdout"
        }
      ]
    },
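    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To make the padding scheme concrete, here is a toy-sized sketch of what `my_inputs` does for a single device. The names `toy_ids`, `toy_len`, `toy_pad_amount` and `pad_and_mask` are made up for this illustration and are not part of Trax; the real pipeline above uses `IDS` and a window of 524,288 positions. The mask is 1.0 on real tokens and 0.0 on padding, so padded positions should not contribute to the loss."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical toy data: 5 token ids placed in a window of 8 positions.\n",
        "toy_ids = np.array([11, 12, 13, 14, 15], dtype=np.int32)\n",
        "toy_len = 8\n",
        "toy_pad_amount = toy_len - len(toy_ids)  # 3 padding positions in total\n",
        "\n",
        "def pad_and_mask(ids, pad_before, pad_total):\n",
        "  # Zero-pads ids on both sides and builds a mask that is 1.0 on real tokens.\n",
        "  widths = (pad_before, pad_total - pad_before)\n",
        "  padded = np.pad(ids, widths, mode='constant')\n",
        "  mask = np.pad(np.ones_like(ids, dtype=np.float32), widths, mode='constant')\n",
        "  return padded, mask\n",
        "\n",
        "# Same kind of draw as in my_inputs: an offset in [0, toy_pad_amount).\n",
        "toy_offset = np.random.choice(toy_pad_amount)\n",
        "toy_padded, toy_mask = pad_and_mask(toy_ids, toy_offset, toy_pad_amount)\n",
        "print(toy_padded)\n",
        "print(toy_mask)"
      ],
      "execution_count": 0,
      "outputs": []
    },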
    {
      "cell_type": "code",
      "metadata": {
        "id": "Ei90LdK024r_",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Configure hyperparameters.\n",
        "gin.parse_config(\"\"\"\n",
        "import trax.layers\n",
        "import trax.models\n",
        "import trax.optimizers\n",
        "import trax.supervised.inputs\n",
        "import trax.supervised.trainer_lib\n",
        "\n",
        "# Parameters that will vary between experiments:\n",
        "# ==============================================================================\n",
        "train.model = @trax.models.ReformerLM\n",
        "# Our model will have 6 layers, alternating between the LSH attention proposed\n",
        "# in the Reformer paper and local attention within a certain context window.\n",
        "n_layers = 6\n",
        "attn_type = [\n",
        "  @SelfAttention,\n",
        "  @LSHSelfAttention,\n",
        "  @SelfAttention,\n",
        "  @LSHSelfAttention,\n",
        "  @SelfAttention,\n",
        "  @LSHSelfAttention,\n",
        "  ]\n",
        "share_qk = False  # LSH attention ignores this flag and always shares q & k\n",
        "n_heads = 2\n",
        "attn_kv = 64\n",
        "dropout = 0.05\n",
        "n_tokens = 524288\n",
        "\n",
        "# Parameters for MultifactorSchedule:\n",
        "# ==============================================================================\n",
        "MultifactorSchedule.constant = 0.01\n",
        "MultifactorSchedule.factors = 'constant * linear_warmup * cosine_decay'\n",
        "MultifactorSchedule.warmup_steps = 100\n",
        "MultifactorSchedule.steps_per_cycle = 900\n",
        "\n",
        "# Parameters for Adam:\n",
        "# ==============================================================================\n",
        "Adam.weight_decay_rate = 0.0\n",
        "Adam.b1 = 0.86\n",
        "Adam.b2 = 0.92\n",
        "Adam.eps = 1e-9\n",
        "\n",
        "# Parameters for SelfAttention:\n",
        "# ==============================================================================\n",
        "SelfAttention.attention_dropout = 0.05\n",
        "SelfAttention.chunk_len = 64\n",
        "SelfAttention.n_chunks_before = 1\n",
        "SelfAttention.n_parallel_heads = 1\n",
        "\n",
        "# Parameters for LSHSelfAttention:\n",
        "# ==============================================================================\n",
        "LSHSelfAttention.attention_dropout = 0.0\n",
        "LSHSelfAttention.chunk_len = 64\n",
        "LSHSelfAttention.n_buckets = [64, 128]\n",
        "LSHSelfAttention.n_chunks_after = 0\n",
        "LSHSelfAttention.n_chunks_before = 1\n",
        "LSHSelfAttention.n_hashes = 1\n",
        "LSHSelfAttention.n_parallel_heads = 1\n",
        "LSHSelfAttention.predict_drop_len = 128\n",
        "LSHSelfAttention.predict_mem_len = 1024\n",
        "\n",
        "# Parameters for ReformerLM:\n",
        "# ==============================================================================\n",
        "ReformerLM.attention_type = %attn_type\n",
        "ReformerLM.d_attention_key = %attn_kv\n",
        "ReformerLM.d_attention_value = %attn_kv\n",
        "ReformerLM.d_model = 256\n",
        "ReformerLM.d_ff = 512\n",
        "ReformerLM.dropout = %dropout\n",
        "ReformerLM.ff_activation = @trax.layers.Relu\n",
        "ReformerLM.max_len = %n_tokens\n",
        "ReformerLM.mode = 'train'\n",
        "ReformerLM.n_heads = %n_heads\n",
        "ReformerLM.n_layers = %n_layers\n",
        "ReformerLM.vocab_size = 320\n",
        "ReformerLM.share_qk = %share_qk\n",
        "ReformerLM.axial_pos_shape = (512, 1024)\n",
        "ReformerLM.d_axial_pos_embs = (64, 192)\n",
        "\"\"\")"
      ],
      "execution_count": 0,
      "outputs": []
    },
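    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The axial position settings in the config are tied to the other sizes. As far as we understand axial position embeddings, the product of `axial_pos_shape` has to match the sequence length, and the entries of `d_axial_pos_embs` have to add up to `d_model`. The quick check below re-types the numbers from the config as plain Python variables (it does not read them back from gin), so treat it as a sanity check on the arithmetic rather than as part of the Trax API. It also compares the number of position-embedding parameters against a single full table, under the assumption that each axis stores one table of shape (axis length, axis embedding size)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Values copied by hand from the gin config above.\n",
        "n_tokens = 524288\n",
        "d_model = 256\n",
        "axial_pos_shape = (512, 1024)\n",
        "d_axial_pos_embs = (64, 192)\n",
        "\n",
        "# The 2-D grid of axial positions must have one slot per token.\n",
        "assert axial_pos_shape[0] * axial_pos_shape[1] == n_tokens\n",
        "# The per-axis embeddings are concatenated into a d_model-sized vector.\n",
        "assert sum(d_axial_pos_embs) == d_model\n",
        "\n",
        "# Position-embedding parameters: one small table per axis vs. one full table.\n",
        "axial_params = sum(n * d for n, d in zip(axial_pos_shape, d_axial_pos_embs))\n",
        "full_params = n_tokens * d_model\n",
        "print('axial:', axial_params, 'full table:', full_params)"
      ],
      "execution_count": 0,
      "outputs": []
    },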
    {
      "cell_type": "code",
      "metadata": {
        "id": "RGGt0WaT3a-h",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Set up a Trainer.\n",
        "output_dir = os.path.expanduser('~/train_dir/')\n",
        "!rm -f ~/train_dir/model.pkl  # Remove old model\n",
        "trainer = trax.supervised.Trainer(\n",
        "    model=trax.models.ReformerLM,\n",
        "    loss_fn=trax.layers.CrossEntropyLoss,\n",
        "    optimizer=trax.optimizers.Adam,\n",
        "    lr_schedule=trax.lr.MultifactorSchedule,\n",
        "    inputs=trax.supervised.inputs.Inputs(my_inputs),\n",
        "    output_dir=output_dir,\n",
        "    has_weights=True)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "y6VQkmKO3a1L",
        "colab_type": "code",
        "outputId": "d5519372-44e9-4311-f84b-931b12e85232",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 269
        }
      },
      "source": [
        "# Run one training step, to make sure the model fits in memory.\n",
        "# The first time trainer.train_epoch is called, it will JIT the entire network\n",
        "# architecture, which takes around 2 minutes. The JIT-compiled model is saved\n",
        "# so subsequent runs will be much faster than the first.\n",
        "trainer.train_epoch(n_steps=1, n_eval_steps=1)"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "\n",
            "Step      1: Ran 1 train steps in 161.13 secs\n",
            "Step      1: Evaluation\n",
            "Step      1: train                   accuracy |  0.00300037\n",
            "Step      1: train                       loss |  6.38712692\n",
            "Step      1: train         neg_log_perplexity |  6.38712692\n",
            "Step      1: train          sequence_accuracy |  0.00000000\n",
            "Step      1: train weights_per_batch_per_core |  513812.00000000\n",
            "Step      1: eval                    accuracy |  0.00302907\n",
            "Step      1: eval                        loss |  6.38727570\n",
            "Step      1: eval          neg_log_perplexity |  6.38727570\n",
            "Step      1: eval           sequence_accuracy |  0.00000000\n",
            "Step      1: eval  weights_per_batch_per_core |  513812.00000000\n",
            "Step      1: Finished evaluation\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EFnX4G6z3asD",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Train for 600 steps total (1 step already run above, plus 9 + 59 * 10 here).\n",
        "# The first ~20 steps are slow to run, but after that it reaches steady-state\n",
        "# speed. This will take at least 30 minutes to run to completion, but can safely\n",
        "# be interrupted by selecting \"Runtime > Interrupt Execution\" from the menu.\n",
        "# The language model won't be exceptionally good when trained for just a few\n",
        "# steps and with minimal regularization. However, we can still sample from it to\n",
        "# see what it learns.\n",
        "trainer.train_epoch(n_steps=9, n_eval_steps=1)\n",
        "for _ in range(59):\n",
        "  trainer.train_epoch(n_steps=10, n_eval_steps=1)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zY3hpgnI5Rgn",
        "colab_type": "text"
      },
      "source": [
        "## Sample from the model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ffeLSbJk35pv",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# As we report in the Reformer paper, increasing the number of hashing rounds\n",
        "# helps with quality. We can even increase the number of hashing rounds at\n",
        "# evaluation time only.\n",
        "gin.parse_config(\"\"\"LSHSelfAttention.n_hashes = 4\"\"\")"
      ],
      "execution_count": 0,
      "outputs": []
    },
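    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To give a feel for what a hashing round is, here is a small NumPy sketch of the angular LSH scheme described in the Reformer paper: each vector is projected with a random matrix, and its bucket is the argmax over the projections and their negations. This is an illustration only, not the Trax `LSHSelfAttention` implementation; `lsh_buckets`, `toy_vectors` and all sizes below are made up for the example. With several rounds, two similar vectors get several independent chances to land in the same bucket."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "\n",
        "def lsh_buckets(vectors, n_buckets, n_hashes, rng):\n",
        "  # Returns an int array of shape (n_hashes, n_vectors) holding bucket ids.\n",
        "  d = vectors.shape[-1]\n",
        "  # One independent random projection per hashing round.\n",
        "  rotations = rng.normal(size=(n_hashes, d, n_buckets // 2))\n",
        "  rotated = np.einsum('nd,hdb->hnb', vectors, rotations)\n",
        "  # Bucket = argmax over [xR; -xR], which gives n_buckets possible values.\n",
        "  return np.argmax(np.concatenate([rotated, -rotated], axis=-1), axis=-1)\n",
        "\n",
        "rng = np.random.RandomState(0)\n",
        "toy_vectors = rng.normal(size=(16, 64))  # 16 hypothetical query/key vectors\n",
        "toy_buckets = lsh_buckets(toy_vectors, n_buckets=8, n_hashes=4, rng=rng)\n",
        "print(toy_buckets.shape)  # one row of bucket ids per hashing round"
      ],
      "execution_count": 0,
      "outputs": []
    },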
    {
      "cell_type": "code",
      "metadata": {
        "id": "Eq45QGXKG3UG",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Construct the decoder instance.\n",
        "# Unfortunately this code ends up leaking some memory, so we can only set up the\n",
        "# decoder once before the memory leak prevents us from running the model and we\n",
        "# have to restart the notebook.\n",
        "sampling_decoder = Search(\n",
        "    trax.models.ReformerLM,\n",
        "    trainer.model_weights,\n",
        "    temperature=1.0,\n",
        "    max_decode_len=128,\n",
        "    )"
      ],
      "execution_count": 0,
      "outputs": []
    },
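    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The `temperature` argument controls how sharply the decoder samples from the model's output distribution. The sketch below shows the usual definition: divide the logits by the temperature, apply a softmax, then draw a token. It is a generic illustration, and we have not checked that `Search` implements sampling in exactly this way. `sample_token` and the toy logits are made up for the example. Lower temperatures concentrate the draws on the highest-scoring token, while higher temperatures spread them out."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "from scipy.special import softmax\n",
        "\n",
        "def sample_token(logits, temperature, rng):\n",
        "  # Draws one token id from softmax(logits / temperature).\n",
        "  probs = softmax(np.asarray(logits, dtype=np.float64) / temperature)\n",
        "  return rng.choice(len(probs), p=probs)\n",
        "\n",
        "rng = np.random.RandomState(0)\n",
        "toy_logits = [2.0, 1.0, 0.1]  # hypothetical scores for a 3-token vocabulary\n",
        "for temperature in (0.5, 1.0, 2.0):\n",
        "  draws = [sample_token(toy_logits, temperature, rng) for _ in range(1000)]\n",
        "  # Empirical frequency of each token id at this temperature.\n",
        "  print(temperature, np.bincount(draws, minlength=3) / 1000.0)"
      ],
      "execution_count": 0,
      "outputs": []
    },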
    {
      "cell_type": "code",
      "metadata": {
        "id": "dfeXilrHHJ6P",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Sample from the Reformer language model.\n",
        "seqs, scores = sampling_decoder.decode(batch_size=1)\n",
        "sample = seqs[0, -1]\n",
        "TOKENIZER.DecodeIds(sample.tolist())"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "o31Wtxuu5Ehf",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}
